import pickle as pkl
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

try:
    # RMST is a third-party research package that is not on PyPI; fail early
    # with a pointer to its source if it cannot be imported.
    from RMST import RMST
except ImportError as err:
    # Catch only import failures (not KeyboardInterrupt/SystemExit or bugs
    # raised while importing), and chain the original error so the real
    # cause (e.g. a missing dependency of RMST itself) stays visible.
    raise ImportError(
        "RMST code available here: https://github.com/barahona-research-group/RMST"
    ) from err

def _save_matrix_plot(matrix, path):
    """Render *matrix* as an image with a colorbar and save it to *path*.

    The figure is closed after saving so that successive calls do not keep
    accumulating open matplotlib figures.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(matrix)
    fig.colorbar(image, ax=ax)
    fig.savefig(path)
    plt.close(fig)


if __name__ == "__main__":
    # Build a neuron-neuron graph from Hox gene expression:
    #   1. load the expression table and drop uninformative genes,
    #   2. compute pairwise Jaccard similarity between neurons,
    #   3. sparsify the complete similarity graph with RMST,
    #   4. attach neuron metadata to the nodes and pickle the graph.

    # Sparsification strength passed to RMST as ``gamma``.
    rmst_gamma = 0.01
    # Genes whose column sum reaches this fraction of the neuron count are dropped.
    max_gene_fraction = 0.9

    with open("hox_gene_expression_data.pkl", "rb") as f:
        # The last row is discarded; presumably a footer/summary row — TODO confirm.
        data = pkl.load(f)[:-1]
    # Columns from index 4 onwards are used as expression values; the first four
    # are assumed to be metadata (e.g. "Neuron", "Neurotransmitter",
    # "Neuron Class") — verify against the data file.
    exp = data.iloc[:, 4:].to_numpy()

    # Remove genes that can be found in almost every neuron or in no neuron.
    # Assumes a binary (0/1) expression matrix, so a column sum is the number
    # of neurons expressing that gene — TODO confirm.
    n_neurons = exp.shape[0]
    gene_counts = exp.sum(axis=0)
    keep = (gene_counts > 0) & (gene_counts < n_neurons * max_gene_fraction)
    exp = exp[:, keep]

    # Per-neuron metadata, indexed by row position (which is also the node id
    # that nx.from_numpy_array assigns below).
    neurotransmitters = data["Neurotransmitter"].to_list()
    neuron_classes = data["Neuron Class"].to_list()
    neurons = data["Neuron"].to_list()

    dist = squareform(pdist(exp, metric="jaccard"))
    _save_matrix_plot(dist, "distance_matrix.pdf")

    # Jaccard similarity with a zeroed diagonal so the graph has no self-loops.
    similarity = 1.0 - dist
    np.fill_diagonal(similarity, 0.0)
    _save_matrix_plot(similarity, "similarity_matrix.pdf")

    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the
    # replacement and is also available in NetworkX 2.x.
    G = nx.from_numpy_array(similarity)
    G = RMST(G, gamma=rmst_gamma, weighted=True)
    _save_matrix_plot(
        nx.adjacency_matrix(G).toarray(), "sparsified_similarity_matrix.pdf"
    )

    for node in G:
        G.nodes[node]["neurotransmitter"] = neurotransmitters[node]
        G.nodes[node]["neuron_class"] = neuron_classes[node]
        G.nodes[node]["neuron"] = neurons[node]

    # Explicit raise instead of assert so the check also runs under `python -O`.
    if not nx.is_connected(G):
        raise RuntimeError("graph is not connected")

    # nx.write_gpickle was removed in NetworkX 3.0. It was a thin wrapper around
    # pickle.dump, so writing the pickle directly keeps the file format unchanged.
    with open("hox_gene_expression_graph.gpickle", "wb") as f:
        pkl.dump(G, f, protocol=4)
